import os
import traceback
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch
from exllamav3 import Cache, Config, Model
from torch.nn import CrossEntropyLoss
from transformers import GenerationConfig, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from modules import shared
from modules.logging_colors import logger

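# Importing flash_attn here only checks that it is available; the forward calls
# below request the "flash_attn" attention mode, so warn early if it is missing.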
try:
    import flash_attn
except Exception:
    logger.warning('Failed to load flash-attention due to the following error:\n')
    traceback.print_exc()


class Exllamav3HF(PreTrainedModel):
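    """
    Wraps an ExLlamaV3 model and cache behind the transformers PreTrainedModel
    interface so that HF-style generation and sampling code can drive it like a
    regular causal LM.
    """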
    def __init__(self, model_dir):
        super().__init__(PretrainedConfig())
        self.generation_config = GenerationConfig()

        config = Config.from_directory(model_dir)
        self.ex_model = Model.from_config(config)

        # Calculate the closest multiple of 256 at or above the chosen value
        max_tokens = shared.args.max_seq_len
        if max_tokens % 256 != 0:
            adjusted_tokens = ((max_tokens // 256) + 1) * 256
            logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
            max_tokens = adjusted_tokens

        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens)

        # Create load parameters dictionary
        load_params = {'progressbar': True}
        if shared.args.gpu_split:
            split = [float(alloc) for alloc in shared.args.gpu_split.split(",")]
            load_params['use_per_device'] = split

        self.ex_model.load(**load_params)
        self.past_seq = None
        self.max_tokens = max_tokens

        # __call__ expects a separate cache and sequence for the negative
        # (CFG) prompt, so create them when cfg-cache is enabled.
        if shared.args.cfg_cache:
            self.ex_cache_negative = Cache(self.ex_model, max_num_tokens=max_tokens)
            self.past_seq_negative = None

    def _validate_model_class(self):
        pass

    def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
        pass

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        return {'input_ids': input_ids, **kwargs}

    @property
    def device(self) -> torch.device:
        return torch.device(0)

    def __call__(self, *args, **kwargs):
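        # A positional argument signals the negative-prompt (CFG) pass and uses
        # the separate negative cache; the normal pass receives input_ids as a
        # keyword argument. Unless labels are supplied, logits are only
        # computed for the last token.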
        use_cache = kwargs.get('use_cache', True)
        labels = kwargs.get('labels', None)
        past_key_values = kwargs.get('past_key_values', None)

        if len(args) > 0:
            if not shared.args.cfg_cache:
                logger.error("Please enable the cfg-cache option to use CFG with ExLlamav3_HF.")
                return

            input_ids = args[0]
            is_negative = True
            past_seq = self.past_seq_negative
            ex_cache = self.ex_cache_negative
        else:
            input_ids = kwargs['input_ids']
            is_negative = False
            past_seq = self.past_seq
            ex_cache = self.ex_cache

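        # For the negative (CFG) pass, past_key_values holds the previously
        # processed negative-prompt tokens, so prepend them to the new tokens.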
        seq = input_ids[0].tolist()
        if is_negative and past_key_values is not None:
            seq = past_key_values + seq

        seq_tensor = torch.tensor(seq)
        reset = True

        # Make the forward call
        if labels is None:
            if past_seq is not None:
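                # Find the longest prefix shared with the previous call so the
                # cached keys/values for those positions can be reused.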
                min_length = min(past_seq.shape[0], seq_tensor.shape[0])
                indices = torch.nonzero(~torch.eq(past_seq[:min_length], seq_tensor[:min_length]))
                if len(indices) > 0:
                    longest_prefix = indices[0].item()
                else:
                    longest_prefix = min_length

                if longest_prefix > 0:
                    reset = False
                    current_len = longest_prefix
                    if len(seq_tensor) - longest_prefix > 1:
                        self.ex_model.forward(
                            input_ids=seq_tensor[longest_prefix:-1].view(1, -1),
                            params={
                                "attn_mode": "flash_attn",
                                "cache": ex_cache,
                                "past_len": longest_prefix,
                                "batch_shape": (1, self.max_tokens)
                            }
                        )

                        current_len = len(seq_tensor) - 1

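            # No reusable prefix: prefill the whole prompt except the last token.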
            if reset:
                if len(seq_tensor) > 1:
                    self.ex_model.forward(
                        input_ids=seq_tensor[:-1].view(1, -1),
                        params={
                            "attn_mode": "flash_attn",
                            "cache": ex_cache,
                            "past_len": 0,
                            "batch_shape": (1, self.max_tokens)
                        }
                    )

                    current_len = len(seq_tensor) - 1
                else:
                    current_len = 0

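            # Forward only the last token to obtain the next-token logits.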
            logits = self.ex_model.forward(
                input_ids=seq_tensor[-1:].view(1, -1),
                params={
                    "attn_mode": "flash_attn",
                    "cache": ex_cache,
                    "past_len": current_len,
                    "batch_shape": (1, self.max_tokens)
                }
            ).to(input_ids.device).float()
        else:
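            # With labels (e.g. perplexity evaluation), process the full
            # sequence in a single pass with past_len=0.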
            logits = self.ex_model.forward(
                input_ids=seq_tensor.view(1, -1),
                params={
                    "attn_mode": "flash_attn",
                    "cache": ex_cache,
                    "past_len": 0,
                    "batch_shape": (1, self.max_tokens)
                }
            ).float()

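        # Remember the processed sequence so the next call can detect the
        # shared prefix and reuse the cache.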
        if is_negative:
            self.past_seq_negative = seq_tensor
        else:
            self.past_seq = seq_tensor

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, logits.shape[-1])
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return CausalLMOutputWithPast(logits=logits, past_key_values=seq if use_cache else None, loss=loss)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
        assert len(model_args) == 0 and len(kwargs) == 0, "extra args are not currently supported"
        if isinstance(pretrained_model_name_or_path, str):
            pretrained_model_name_or_path = Path(pretrained_model_name_or_path)

        pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path)

        return Exllamav3HF(pretrained_model_name_or_path)
